{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds\n",
    "from t5.data import preprocessors as prep\n",
    "import functools\n",
    "import t5\n",
    "import gin"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tf.compat.v1.enable_eager_execution()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the gin bindings from the operative config file.\n",
    "gin.parse_config_file('pretrained_models_base_operative_config.gin')\n",
    "# SentencePiece model file; passed as `sentencepiece_model_path` when the task is registered.\n",
    "vocab = 'sp10m.cased.t5.model'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['t5-data/dumping-watpadd.tsv',\n",
       " 't5-data/dumping-iium.tsv',\n",
       " 't5-data/dumping-news.tsv',\n",
       " 't5-data/dumping-wiki.tsv',\n",
       " 't5-data/dumping-parliament.tsv',\n",
       " 't5-data/dumping-pdf.tsv']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from glob import glob\n",
    "\n",
    "files = glob('t5-data/dumping-*.tsv')\n",
    "files = [f for f in files if 'pair' not in f]\n",
    "files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "def dumping_dataset(split, shuffle_files = False):\n",
    "    \"\"\"Read a tab-separated file as a dataset of 'title'/'text' dicts.\n",
    "\n",
    "    Args:\n",
    "        split: file name handed directly to `tf.data.TextLineDataset`.\n",
    "        shuffle_files: accepted for interface compatibility and ignored.\n",
    "\n",
    "    Returns:\n",
    "        A `tf.data.Dataset` whose elements are dicts keyed by 'title' and\n",
    "        'text', built from the two tab-separated fields of each line.\n",
    "    \"\"\"\n",
    "    del shuffle_files  # unused\n",
    "\n",
    "    # Each line is split on tabs into two string fields; quote characters\n",
    "    # are treated as ordinary text.\n",
    "    parse_line = functools.partial(\n",
    "        tf.io.decode_csv,\n",
    "        record_defaults = ['', ''],\n",
    "        field_delim = '\\t',\n",
    "        use_quote_delim = False,\n",
    "    )\n",
    "\n",
    "    def to_example(*columns):\n",
    "        return dict(zip(['title', 'text'], columns))\n",
    "\n",
    "    lines = tf.data.TextLineDataset([split])\n",
    "    parsed = lines.map(\n",
    "        parse_line,\n",
    "        num_parallel_calls = tf.data.experimental.AUTOTUNE,\n",
    "    )\n",
    "    return parsed.map(to_example)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/husein/.local/lib/python3.6/site-packages/t5/models/mesh_transformer.py:210: UserWarning: get_sentencepiece_model_path is deprecated. Please pass the mixture or task vocabulary directly to the Mesh TensorFlow Transformer instead.\n",
      "  \"get_sentencepiece_model_path is deprecated. Please pass the mixture or \"\n"
     ]
    }
   ],
   "source": [
    "# Remove any existing registration under this name before adding it again.\n",
    "t5.data.TaskRegistry.remove('dumping_dataset')\n",
    "\n",
    "# `prep` is the `t5.data.preprocessors` module imported in the first cell.\n",
    "# The text preprocessor rekeys each example: 'targets' is taken from 'text'\n",
    "# and 'inputs' is mapped to None.\n",
    "t5.data.TaskRegistry.add(\n",
    "    'dumping_dataset',\n",
    "    dataset_fn = dumping_dataset,\n",
    "    splits = ['train'],\n",
    "    text_preprocessor = functools.partial(\n",
    "        prep.rekey, key_map = {'inputs': None, 'targets': 'text'}\n",
    "    ),\n",
    "    token_preprocessor = prep.unsupervised,\n",
    "    sentencepiece_model_path = vocab,\n",
    "    metric_fns = [],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t5-data/dumping-watpadd.tsv\n",
      "INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "24245it [04:05, 98.86it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t5-data/dumping-iium.tsv\n",
      "INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "19477it [03:12, 100.96it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t5-data/dumping-news.tsv\n",
      "INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "48592it [05:42, 141.75it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t5-data/dumping-wiki.tsv\n",
      "INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "45764it [05:34, 136.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t5-data/dumping-parliament.tsv\n",
      "INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "12812it [02:28, 86.49it/s] \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "t5-data/dumping-pdf.tsv\n",
      "INFO:tensorflow:tokens_length=1137 inputs_length=1024 targets_length=229 noise_density=0.15 mean_noise_span_length=3.0 \n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "10101it [01:40, 100.24it/s]\n"
     ]
    }
   ],
   "source": [
    "# `json` and `os` are used below but were never imported anywhere in the\n",
    "# notebook, so this cell failed with NameError on a fresh kernel.\n",
    "import json\n",
    "import os\n",
    "\n",
    "from tqdm import tqdm\n",
    "\n",
    "# The task lookup does not depend on the file, so do it once up front.\n",
    "dumping_task = t5.data.TaskRegistry.get(\"dumping_dataset\")\n",
    "\n",
    "for file in files:\n",
    "    print(file)\n",
    "    # The output is written to the working directory as '<input file name>.parse'.\n",
    "    output_name = os.path.basename(file)\n",
    "    # The file path is passed as `split`; `dumping_dataset` reads it directly.\n",
    "    ds = dumping_task.get_dataset(split=file, sequence_length={\"inputs\": 1024, \"targets\": 1024})\n",
    "    results = []\n",
    "    for ex in tqdm(tfds.as_numpy(ds)):\n",
    "        # Convert the numpy arrays to plain lists so they are JSON serialisable.\n",
    "        results.append((ex['inputs'].tolist(), ex['targets'].tolist()))\n",
    "\n",
    "    with open(f'{output_name}.parse', 'w') as fopen:\n",
    "        json.dump(results, fopen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
